// RUN: tf-opt %s -tf-replicated-clustering-bridge-v2 --mlir-print-ir-before-all --mlir-print-ir-after-all | FileCheck %s

// CHECK-LABEL: func.func @main
// CHECK  %0 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor<!tf_type.string>
// CHECK  %1 = "tf.ReadVariableOp"(%arg3) : (tensor<*x!tf_type.resource<tensor<128x1024xf32>>>) -> tensor<128x1024xf32>
// CHECK  %2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf_type.resource<tensor<1024xf32>>>) -> tensor<1024xf32>
// CHECK  %cst = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<f32>
// CHECK  %cst_0 = "tf.Const"() <{value = dense<[128, 1024]> : tensor<2xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<2xi64>
// CHECK  %3 = "tf.Fill"(%cst_0, %cst) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<2xi64>, tensor<f32>) -> tensor<128x1024xf32>
// CHECK  %cst_1 = "tf.Const"() <{value = dense<0.000000e+00> : tensor<f32>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<f32>
// CHECK  %cst_2 = "tf.Const"() <{value = dense<1024> : tensor<1xi64>}> {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1xi64>
// CHECK  %4 = "tf.Fill"(%cst_2, %cst_1) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1xi64>, tensor<f32>) -> tensor<1024xf32>
// CHECK  %5:2 = tf_device.replicate([%1, %3] as %arg22: tensor<128x1024xf32>, [%2, %4] as %arg23: tensor<1024xf32>) {n = 2 : i32} {
// CHECK    %8 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({
// CHECK      %11 = "tf.Identity"(%arg22) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<128x1024xf32>) -> tensor<128x1024xf32>
// CHECK      tf_device.return %11 : tensor<128x1024xf32>
// CHECK    }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<128x1024xf32>
// CHECK    %9 = "tf_device.launch"() <{device = "TPU_REPLICATED_HOST_0"}> ({
// CHECK      %11 = "tf.Identity"(%arg23) {_ici_weight_distribution_mlir_bridge_marker = true} : (tensor<1024xf32>) -> tensor<1024xf32>
// CHECK      tf_device.return %11 : tensor<1024xf32>
// CHECK    }) {_ici_weight_distribution_mlir_bridge_marker = true} : () -> tensor<1024xf32>
// CHECK    %10 = "tf_device.cluster_func"(%8, %9) <{func = @_func}> {_dynamic_arg_index = [], _has_manual_control_dependencies = true, _replication_info = "cluster__train_helper", _xla_compile_device_type = "TPU", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], input_sharding_configuration = ["\08\03\1A\01\04\22\04\00\01\02\03", ""], num_cores_per_replica = 4 : i64, output_sharding_configuration = [""], padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true} : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32>
// CHECK    tf_device.return %10 : tensor<*xf32>
// CHECK  }
// CHECK  %6 = "tf.ReadVariableOp"(%arg2) {device = ""} : (tensor<*x!tf_type.resource<tensor<i64>>>) -> tensor<i64>
// CHECK  %7 = "tf.Identity"(%6) {device = ""} : (tensor<i64>) -> tensor<i64>
// CHECK  return %7 : tensor<i64>

// CHECK-LABEL:  func.func private @_func(%arg0: tensor<128x1024xf32> {mhlo.sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}, %arg1: tensor<1024xf32> {mhlo.sharding = ""}) -> (tensor<*xf32> {mhlo.sharding = ""}) {
// CHECK  %cst = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
// CHECK  %0 = "tf.XlaAllReduce"(%arg0, %cst) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<128x1024xf32>, tensor<1x2xi32>) -> tensor<128x1024xf32>
// CHECK  %cst_0 = "tf.Const"() <{value = dense<[[0, 1]]> : tensor<1x2xi32>}> : () -> tensor<1x2xi32>
// CHECK  %1 = "tf.XlaAllReduce"(%arg1, %cst_0) <{mode = "CrossReplica", reduce_op = "Add"}> : (tensor<1024xf32>, tensor<1x2xi32>) -> tensor<1024xf32>
// CHECK  %2 = "tf.XlaSharding"(%0) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {unspecified_dims = []} : (tensor<128x1024xf32>) -> tensor<128x1024xf32>
// CHECK  %3 = "tf.MatMul"(%2, %1) : (tensor<128x1024xf32>, tensor<1024xf32>) -> tensor<*xf32>
// CHECK  return %3 : tensor<*xf32>
// CHECK  }

module attributes {tf.devices = {"/job:tpu_host_worker/replica:0/task:0/device:CPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:0", "/job:tpu_host_worker/replica:0/task:0/device:TPU:1", "/job:tpu_host_worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:1/device:CPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:0", "/job:tpu_host_worker/replica:0/task:1/device:TPU:1", "/job:tpu_host_worker/replica:0/task:1/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:2/device:CPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:0", "/job:tpu_host_worker/replica:0/task:2/device:TPU:1", "/job:tpu_host_worker/replica:0/task:2/device:TPU_SYSTEM:0", "/job:tpu_host_worker/replica:0/task:3/device:CPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:0", "/job:tpu_host_worker/replica:0/task:3/device:TPU:1", "/job:tpu_host_worker/replica:0/task:3/device:TPU_SYSTEM:0"}, tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1857 : i32}} {
  func.func @main(%arg0: tensor<i32> {tf._user_specified_name = "steps", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg1: tensor<*x!tf_type.resource<tensor<i64>>> {tf._user_specified_name = "899", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg2: tensor<*x!tf_type.resource<tensor<i64>>> {tf._user_specified_name = "901", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg3: tensor<*x!tf_type.resource<tensor<128x1024xf32>>> {tf._user_specified_name = "903", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg4: tensor<*x!tf_type.resource<tensor<1024xf32>>> {tf._user_specified_name = "905", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg5: tensor<*x!tf_type.resource<tensor<1024x1xf32>>> {tf._user_specified_name = "907", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg6: tensor<*x!tf_type.resource<tensor<i64>>> {tf._user_specified_name = "909", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg7: tensor<*x!tf_type.resource<tensor<25001x64xf32>>> {tf._user_specified_name = "911", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg8: tensor<*x!tf_type.resource<tensor<25001x64xf32>>> {tf._user_specified_name = "913", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg9: tensor<*x!tf_type.resource<tensor<25001x64xf32>>> {tf._user_specified_name = "915", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg10: tensor<*x!tf_type.resource<tensor<25001x64xf32>>> {tf._user_specified_name = "917", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg11: tensor<*x!tf_type.resource<tensor<25001x32xf32>>> {tf._user_specified_name = "919", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg12: tensor<*x!tf_type.resource<tensor<25001x32xf32>>> {tf._user_specified_name = "921", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg13: tensor<*x!tf_type.resource<tensor<25001x32xf32>>> {tf._user_specified_name = "923", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg14: tensor<*x!tf_type.resource<tensor<25001x32xf32>>> {tf._user_specified_name = "925", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg15: tensor<*x!tf_type.resource<tensor<6x32xf32>>> {tf._user_specified_name = "927", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg16: tensor<*x!tf_type.resource<tensor<6x32xf32>>> {tf._user_specified_name = "929", tf.device = "/job:tpu_host_worker/replica:0/task:1/device:CPU:0"}, %arg17: tensor<*x!tf_type.resource<tensor<6x32xf32>>> {tf._user_specified_name = "931", tf.device = "/job:tpu_host_worker/replica:0/task:2/device:CPU:0"}, %arg18: tensor<*x!tf_type.resource<tensor<6x32xf32>>> {tf._user_specified_name = "933", tf.device = "/job:tpu_host_worker/replica:0/task:3/device:CPU:0"}, %arg19: tensor<*x!tf_type.resource<tensor<128x1024xf32>>> {tf._user_specified_name = "935", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg20: tensor<*x!tf_type.resource<tensor<1024xf32>>> {tf._user_specified_name = "937", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}, %arg21: tensor<*x!tf_type.resource<tensor<1024x1xf32>>> {tf._user_specified_name = "939", tf.device = "/job:tpu_host_worker/replica:0/task:0/device:CPU:0"}) -> tensor<*xi64> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "steps,unknown,unknown_0,unknown_1,unknown_2,unknown_3,unknown_4,unknown_5,unknown_6,unknown_7,unknown_8,unknown_9,unknown_10,unknown_11,unknown_12,unknown_13,unknown_14,unknown_15,unknown_16,unknown_17,unknown_18,unknown_19", outputs = "statefulpartitionedcall_RetVal"}} {
    %0 = tf_executor.graph {
      %outputs, %control = tf_executor.island wraps "tf.Const"() <{value = dense<false> : tensor<i1>}> {device = ""} : () -> tensor<i1>
      %outputs_0, %control_1 = tf_executor.island wraps "tf.StatefulPartitionedCall"(%arg0, %outputs, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21) <{config = "", config_proto = "\0A\07\0A\03CPU\10\01\0A\07\0A\03GPU\10\002\02J\00\82\01\14h\01\88\01\01\BA\01\0C\0A\0Astandalone", executor_type = "", f = @__inference__train_helper_8510}> {_collective_manager_ids = [], _read_only_resource_inputs = [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19], device = ""} : (tensor<i32>, tensor<i1>, tensor<*x!tf_type.resource<tensor<i64>>>, tensor<*x!tf_type.resource<tensor<i64>>>, tensor<*x!tf_type.resource<tensor<128x1024xf32>>>, tensor<*x!tf_type.resource<tensor<1024xf32>>>, tensor<*x!tf_type.resource<tensor<1024x1xf32>>>, tensor<*x!tf_type.resource<tensor<i64>>>, tensor<*x!tf_type.resource<tensor<25001x64xf32>>>, tensor<*x!tf_type.resource<tensor<25001x64xf32>>>, tensor<*x!tf_type.resource<tensor<25001x64xf32>>>, tensor<*x!tf_type.resource<tensor<25001x64xf32>>>, tensor<*x!tf_type.resource<tensor<25001x32xf32>>>, tensor<*x!tf_type.resource<tensor<25001x32xf32>>>, tensor<*x!tf_type.resource<tensor<25001x32xf32>>>, tensor<*x!tf_type.resource<tensor<25001x32xf32>>>, tensor<*x!tf_type.resource<tensor<6x32xf32>>>, tensor<*x!tf_type.resource<tensor<6x32xf32>>>, tensor<*x!tf_type.resource<tensor<6x32xf32>>>, tensor<*x!tf_type.resource<tensor<6x32xf32>>>, tensor<*x!tf_type.resource<tensor<128x1024xf32>>>, tensor<*x!tf_type.resource<tensor<1024xf32>>>, tensor<*x!tf_type.resource<tensor<1024x1xf32>>>) -> tensor<*xi64>
      tf_executor.fetch %outputs_0 : tensor<*xi64>
    }
    return %0 : tensor<*xi64>
  }
  func.func private @__inference__train_helper_8510(%arg0: tensor<i32> {tf._user_specified_name = "steps"}, %arg1: tensor<i1> {tf._user_specified_name = "include_summaries"}, %arg2: tensor<!tf_type.resource> {tf._user_specified_name = "input_5"}, %arg3: tensor<!tf_type.resource> {tf._user_specified_name = "input_6"}, %arg4: tensor<!tf_type.resource> {tf._user_specified_name = "input_7"}, %arg5: tensor<!tf_type.resource> {tf._user_specified_name = "input_8"}, %arg6: tensor<!tf_type.resource> {tf._user_specified_name = "input_9"}, %arg7: tensor<!tf_type.resource> {tf._user_specified_name = "input_10"}, %arg8: tensor<!tf_type.resource> {tf._user_specified_name = "input_11"}, %arg9: tensor<!tf_type.resource> {tf._user_specified_name = "input_12"}, %arg10: tensor<!tf_type.resource> {tf._user_specified_name = "input_13"}, %arg11: tensor<!tf_type.resource> {tf._user_specified_name = "input_14"}, %arg12: tensor<!tf_type.resource> {tf._user_specified_name = "input_15"}, %arg13: tensor<!tf_type.resource> {tf._user_specified_name = "input_16"}, %arg14: tensor<!tf_type.resource> {tf._user_specified_name = "input_17"}, %arg15: tensor<!tf_type.resource> {tf._user_specified_name = "input_18"}, %arg16: tensor<!tf_type.resource> {tf._user_specified_name = "input_19"}, %arg17: tensor<!tf_type.resource> {tf._user_specified_name = "input_20"}, %arg18: tensor<!tf_type.resource> {tf._user_specified_name = "input_21"}, %arg19: tensor<!tf_type.resource> {tf._user_specified_name = "input_22"}, %arg20: tensor<!tf_type.resource> {tf._user_specified_name = "input_23"}, %arg21: tensor<!tf_type.resource> {tf._user_specified_name = "input_24"}, %arg22: tensor<!tf_type.resource> {tf._user_specified_name = "input_25"}) -> tensor<*xi64> attributes {tf._construction_context = "kEagerRuntime", tf._disable_acd = true, tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>], tf.signature.is_stateful} {
    %0 = tf_executor.graph {
      %control = tf_executor.island wraps "tf.NoOp"() {_pivot_for_cluster = "cluster__train_helper", device = ""} : () -> ()
      %control_0 = tf_executor.island(%control) wraps "tf.NoOp"() {_has_manual_control_dependencies = true, _tpu_replicate = "cluster__train_helper", device = ""} : () -> ()
      %control_1 = tf_executor.island(%control) wraps "tf.TPUReplicateMetadata"() <{allow_soft_placement = false, computation_shape = [], device_assignment = [0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 0], host_compute_core = [], num_cores_per_replica = 4 : i64, num_replicas = 2 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_TOP_LEVEL_WHILE_LOOP", topology = "\0A\04\02\02\02\01\10\04\18\02\22 \00\00\00\00\00\01\00\00\01\00\00\00\01\01\00\00\00\00\01\00\00\01\01\00\01\00\01\00\01\01\01\00*\02\08\01", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = true, use_tpu = true}> {_has_manual_control_dependencies = true, _tpu_replicate = "cluster__train_helper", device = ""} : () -> ()
      %outputs, %control_2 = tf_executor.island(%control_1) wraps "tf.Const"() <{value = dense<0> : tensor<i32>}> {_tpu_replicate = "cluster__train_helper", device = ""} : () -> tensor<i32>
      %outputs_3, %control_4 = tf_executor.island(%control_1) wraps "tf.Const"() <{value = dense<0> : tensor<i32>}> {_tpu_replicate = "cluster__train_helper", device = ""} : () -> tensor<i32>
      %outputs_5, %control_6 = tf_executor.island(%control_1) wraps "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster__train_helper", device = ""} : () -> tensor<!tf_type.string>
      %outputs_7, %control_8 = tf_executor.island(%control_1) wraps "tf.Const"() <{value = dense<0> : tensor<i32>}> {_tpu_replicate = "cluster__train_helper", device = ""} : () -> tensor<i32>
      %outputs_9, %control_10 = tf_executor.island(%control_1) wraps "tf.Const"() <{value = dense<-1> : tensor<i32>}> {_tpu_replicate = "cluster__train_helper", device = ""} : () -> tensor<i32>
      %outputs_8, %control_9 = tf_executor.island wraps "tf.ReadVariableOp"(%arg4) {_tpu_replicate = "cluster__train_helper", device = ""} : (tensor<!tf_type.resource>) -> tensor<*xf32>
      %outputs_10, %control_11 = tf_executor.island wraps "tf.ReadVariableOp"(%arg5) {_tpu_replicate = "cluster__train_helper", device = ""} : (tensor<!tf_type.resource>) -> tensor<*xf32>
      %outputs_12, %control_382 = tf_executor.island wraps "tf.XlaSharding"(%outputs_8) <{_XlaSharding = "\08\03\1A\01\04\22\04\00\01\02\03", sharding = "\08\03\1A\01\04\22\04\00\01\02\03"}> {_tpu_replicate = "cluster__train_helper", device = "", unspecified_dims = []} : (tensor<*xf32>) -> tensor<*xf32>
      %outputs_11, %control_12 = tf_executor.island wraps "tf.MatMul"(%outputs_12, %outputs_10) {_tpu_replicate = "cluster__train_helper", device = ""} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32>
      %outputs_13, %control_14 = tf_executor.island wraps "tf.Identity"(%outputs_11) {_tpu_output_identity = true, _tpu_replicate = "cluster__train_helper", device = ""} : (tensor<*xf32>) -> tensor<*xf32>
      %outputs_15:2, %control_16 = tf_executor.island wraps "tf.TPUReplicatedOutput"(%outputs_13) {device = ""} : (tensor<*xf32>) -> (tensor<*xf32>, tensor<*xf32>)
      %outputs_17, %control_18 = tf_executor.island(%control_0) wraps "tf.Identity"(%outputs_15#0) {_has_manual_control_dependencies = true, device = ""} : (tensor<*xf32>) -> tensor<*xf32>
      %control_19 = tf_executor.island(%control_18) wraps "tf.NoOp"() {_has_manual_control_dependencies = true, device = ""} : () -> ()
      %outputs_20, %control_21 = tf_executor.island(%control_18, %control_19) wraps "tf.ReadVariableOp"(%arg3) {device = ""} : (tensor<!tf_type.resource>) -> tensor<*xi64>
      %outputs_22, %control_23 = tf_executor.island wraps "tf.Identity"(%outputs_20) {device = ""} : (tensor<*xi64>) -> tensor<*xi64>
      %outputs_24, %control_25 = tf_executor.island(%control_0) wraps "tf.Identity"(%outputs_15#1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32>
      tf_executor.fetch %outputs_22 : tensor<*xi64>
    }
    return %0 : tensor<*xi64>
  }
}